import torch
import torch.nn as nn
import torch.nn.functional as F

class LabelSmoothEntropy(nn.Module):
    """Cross-entropy loss computed against label-smoothed one-hot targets.

    The target class receives probability ``1 - smooth`` and the remaining
    ``smooth`` mass is spread uniformly over the other ``C - 1`` classes.

    Args:
        smooth: Total probability mass moved off the target class.
        class_weights: Optional per-class weights (a tensor or a sequence of
            length ``C``) multiplied into each class term before summing.
        size_average: Reduction over the batch: ``'mean'``, ``'sum'``, or
            ``'none'`` (return the per-sample losses unreduced).
    """

    def __init__(self, smooth=0.1, class_weights=None, size_average='mean'):
        super().__init__()
        self.size_average = size_average
        self.smooth = smooth
        self.class_weights = class_weights

    def forward(self, preds, targets):
        """Compute the label-smoothed cross-entropy.

        Args:
            preds: Unnormalized logits of shape ``(N, C)``; classes lie on dim 1.
            targets: Integer class indices of shape ``(N,)``.

        Returns:
            A scalar tensor for ``'mean'``/``'sum'``, or a ``(N,)`` tensor of
            per-sample losses for ``'none'``.

        Raises:
            NotImplementedError: If ``size_average`` is not a supported reduction.
        """
        # The number of classes is the size of dim 1 (the softmax axis),
        # not the batch size on dim 0.
        num_classes = preds.shape[1]
        lb_pos = 1 - self.smooth
        # With a single class there are no negatives; avoid dividing by zero.
        lb_neg = self.smooth / max(num_classes - 1, 1)

        smoothed_lb = torch.full_like(preds, lb_neg)
        # scatter_ requires int64 indices, so normalize the target dtype.
        smoothed_lb.scatter_(1, targets.long()[:, None], lb_pos)

        log_soft = F.log_softmax(preds, dim=1)
        loss = -log_soft * smoothed_lb

        if self.class_weights is not None:
            # Accept lists/tuples and follow preds to its device: a plain
            # attribute is not moved by Module.to() / Module.cuda().
            weights = torch.as_tensor(self.class_weights, device=preds.device)
            loss = loss * weights[None, :]

        loss = loss.sum(1)

        if self.size_average == 'mean':
            return loss.mean()
        if self.size_average == 'sum':
            return loss.sum()
        if self.size_average == 'none':
            return loss
        raise NotImplementedError(
            f"unsupported size_average={self.size_average!r}; "
            "expected 'mean', 'sum' or 'none'"
        )